
import sys
sys.path.append('./')
from utils.result_preprocess import OS_GENESIS_PREPROCESS

class OS_Genesis_Agent:
    """Thin wrapper that routes OS-Genesis checkpoints to a backend agent.

    The backend is chosen from the ``policy_lm`` string: names containing
    ``"OS-Genesis-7B-AC"`` use ``Qwen2_VL_Agent``; everything else uses
    ``InternVL_Agent``. The selected backend is stored on ``self.agent``.
    """

    def __init__(self, device, accelerator, cache_dir='~/.cache', dropout=0.5, policy_lm=None):
        """Select and construct the backend agent for ``policy_lm``.

        Args:
            device: Forwarded unchanged to the backend agent.
            accelerator: Forwarded unchanged to the backend agent.
            cache_dir: Accepted for interface compatibility; not used by
                this class.
            dropout: Accepted for interface compatibility; not used by
                this class.
            policy_lm: Model name/path used for substring-based backend
                selection. Must not be ``None``.

        Raises:
            TypeError: If ``policy_lm`` is ``None``.
        """
        # Fail early with an actionable message instead of the opaque
        # "argument of type 'NoneType' is not iterable" that the substring
        # checks below would otherwise produce for the default value.
        if policy_lm is None:
            raise TypeError(
                "policy_lm must be a model name or path string, got None"
            )

        # Not assigned anywhere else in this class; the loaded model lives on
        # ``self.agent.model`` (see ``_load_model``).
        self.model = None
        self.policy_lm = policy_lm
        self.device = device
        self.accelerator = accelerator

        # Backends are imported lazily so only the selected one is loaded.
        if "OS-Genesis-7B-AC" in self.policy_lm:
            from gui_speaker.models.QwenVL import Qwen2_VL_Agent
            self.agent = Qwen2_VL_Agent(device=device, accelerator=accelerator, policy_lm=self.policy_lm)
        else:
            from gui_speaker.models.InternVL import InternVL_Agent
            self.agent = InternVL_Agent(device=device, accelerator=accelerator, policy_lm=self.policy_lm)
        self.res_pre_process = self._res_pre_process()

    def _res_pre_process(self):
        """Return a new ``OS_GENESIS_PREPROCESS`` instance."""
        return OS_GENESIS_PREPROCESS()

    def _load_model(self):
        """Load the backend's model and store it on ``self.agent.model``."""
        self.agent.model = self.agent._load_model()

    def get_action(self, obs, args):
        """Delegate action prediction to the backend agent.

        When ``policy_lm`` contains neither ``"4B"`` nor ``"8B"``,
        ``obs['messages']`` is first overwritten (in place, on the caller's
        mapping) with a user turn holding ``obs['question']`` followed by an
        assistant turn whose content is ``None``.

        Args:
            obs: Observation mapping; must contain ``'question'`` when the
                messages are built here.
            args: Forwarded unchanged to the backend's ``get_action``.

        Returns:
            Whatever the backend agent's ``get_action`` returns.
        """
        uses_raw_obs = "4B" in self.policy_lm or "8B" in self.policy_lm
        if not uses_raw_obs:
            obs['messages'] = [
                {
                    "role": "user",
                    "content": obs['question']
                },
                {
                    "role": "assistant",
                    "content": None
                }
            ]
        return self.agent.get_action(obs, args)

